{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "3.16.2 获取和读取数据集\n",
    "比赛数据分为训练数据集和测试数据集。两个数据集都包括每栋房子的特征，如街道类型、建造年份、房顶类型、地下室状况等特征值。这些特征值有连续的数字、离散的标签甚至是缺失值“na”。只有训练数据集包括了每栋房子的价格，也就是标签。我们可以访问比赛网页，点击图3.8中的“Data”标签，并下载这些数据集。\n",
    "\n",
    "我们将通过pandas库读入并处理数据。在导入本节需要的包前请确保已安装pandas库，否则请参考下面的代码注释。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.0.0\n"
     ]
    }
   ],
   "source": [
    "# 如果没有安装pandas，则反注释下面一行\n",
    "# !pip install pandas\n",
    "\n",
    "%matplotlib inline\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from tensorflow import initializers as init\n",
    "print(tf.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "假设解压后的数据位于./data/kaggle_house/目录，它包括两个csv文件。下面使用pandas读取这两个文件"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "train_data = pd.read_csv('../data/kaggle_house/train.csv')\n",
    "test_data = pd.read_csv('../data/kaggle_house/test.csv')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "训练数据集包括1460个样本、80个特征和1个标签。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1460, 81)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_data.shape # 输出 (1460, 81)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "测试数据集包括1459个样本和80个特征。我们需要将测试数据集中每个样本的标签预测出来。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1459, 80)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_data.shape # 输出 (1459, 80)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "让我们来查看前4个样本的前4个特征、后2个特征和标签（SalePrice）："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style>\n",
       "    .dataframe thead tr:only-child th {\n",
       "        text-align: right;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: left;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Id</th>\n",
       "      <th>MSSubClass</th>\n",
       "      <th>MSZoning</th>\n",
       "      <th>LotFrontage</th>\n",
       "      <th>SaleType</th>\n",
       "      <th>SaleCondition</th>\n",
       "      <th>SalePrice</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>60</td>\n",
       "      <td>RL</td>\n",
       "      <td>65.0</td>\n",
       "      <td>WD</td>\n",
       "      <td>Normal</td>\n",
       "      <td>208500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2</td>\n",
       "      <td>20</td>\n",
       "      <td>RL</td>\n",
       "      <td>80.0</td>\n",
       "      <td>WD</td>\n",
       "      <td>Normal</td>\n",
       "      <td>181500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3</td>\n",
       "      <td>60</td>\n",
       "      <td>RL</td>\n",
       "      <td>68.0</td>\n",
       "      <td>WD</td>\n",
       "      <td>Normal</td>\n",
       "      <td>223500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4</td>\n",
       "      <td>70</td>\n",
       "      <td>RL</td>\n",
       "      <td>60.0</td>\n",
       "      <td>WD</td>\n",
       "      <td>Abnorml</td>\n",
       "      <td>140000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Id  MSSubClass MSZoning  LotFrontage SaleType SaleCondition  SalePrice\n",
       "0   1          60       RL         65.0       WD        Normal     208500\n",
       "1   2          20       RL         80.0       WD        Normal     181500\n",
       "2   3          60       RL         68.0       WD        Normal     223500\n",
       "3   4          70       RL         60.0       WD       Abnorml     140000"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_data.iloc[0:4, [0, 1, 2, 3, -3, -2, -1]]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "可以看到第一个特征是Id，它能帮助模型记住每个训练样本，但难以推广到测试样本，所以我们不使用它来训练。我们将所有的训练数据和测试数据的79个特征按样本连结。train_data.iloc[:, 1:-1]列出从第一列到最后一列所有的样本，concat将样本按列连接。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "all_features = pd.concat((train_data.iloc[:, 1:-1], test_data.iloc[:, 1:]))"
   ]
  },
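  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a small illustration of the slicing and concatenation above, the following sketch uses two made-up frames. The names `toy_train`, `toy_test`, and `toy_all` and their values are hypothetical and are not part of the competition data; the sketch only assumes the same layout as the real files, with Id first and SalePrice last in the training frame."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Hypothetical toy frames that mimic the column layout of train.csv and test.csv\n",
    "toy_train = pd.DataFrame({'Id': [1, 2], 'LotArea': [8450, 9600],\n",
    "                          'MSZoning': ['RL', 'RM'], 'SalePrice': [208500, 181500]})\n",
    "toy_test = pd.DataFrame({'Id': [3], 'LotArea': [11250], 'MSZoning': ['RL']})\n",
    "\n",
    "# iloc[:, 1:-1] keeps every row and drops the first (Id) and last (SalePrice) columns;\n",
    "# iloc[:, 1:] keeps every row and drops only Id\n",
    "toy_all = pd.concat((toy_train.iloc[:, 1:-1], toy_test.iloc[:, 1:]))\n",
    "\n",
    "# Rows are stacked (2 training + 1 test sample); only the feature columns remain\n",
    "print(toy_all.shape)\n",
    "print(list(toy_all.columns))"
   ]
  },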
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "3.16.3. 预处理数据¶\n",
    "我们对连续数值的特征做标准化（standardization）：设该特征在整个数据集上的均值为μ，标准差为σ。那么，我们可以将该特征的每个值先减去μ再除以σ得到标准化后的每个特征值。对于缺失的特征值，我们将其替换成该特征的均值。!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "#numeric_features表示所有值为连续值得下标\n",
    "numeric_features = all_features.dtypes[all_features.dtypes != 'object'].index\n",
    "all_features[numeric_features] = all_features[numeric_features].apply(\n",
    "    lambda x: (x - x.mean()) / (x.std()))\n",
    "# 标准化后，每个特征的均值变为0，所以可以直接用0来替换缺失值，fillna（0）可以对缺失值进行填充\n",
    "all_features[numeric_features] = all_features[numeric_features].fillna(0)"
   ]
  },
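  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The reasoning behind filling with 0 can be checked on a tiny example. This is a minimal sketch on a made-up column (the frame `toy` and its values are hypothetical). It relies on pandas skipping missing values by default when computing `mean()` and `std()`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "# Hypothetical numeric column with one missing value\n",
    "toy = pd.DataFrame({'LotFrontage': [60.0, 80.0, np.nan, 100.0]})\n",
    "\n",
    "# mean() and std() ignore NaN, so the statistics come from 60, 80 and 100\n",
    "standardized = toy.apply(lambda x: (x - x.mean()) / x.std())\n",
    "\n",
    "# The standardized column has mean 0, so filling NaN with 0 is equivalent\n",
    "# to filling the raw column with its mean before standardizing\n",
    "standardized = standardized.fillna(0)\n",
    "print(standardized['LotFrontage'].tolist())"
   ]
  },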
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "接下来将离散数值转成指示特征。举个例子，假设特征MSZoning里面有两个不同的离散值RL和RM，那么这一步转换将去掉MSZoning特征，并新加两个特征MSZoning_RL和MSZoning_RM，其值为0或1。如果一个样本原来在MSZoning里的值为RL，那么有MSZoning_RL=1且MSZoning_RM=0。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# dummy_na=True将缺失值也当作合法的特征值并为其创建指示特征\n",
    "all_features = pd.get_dummies(all_features, dummy_na=True)"
   ]
  },
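  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The MSZoning example can be reproduced on a three-row frame. This is only a sketch with made-up data (the names `toy` and `dummies` are hypothetical); it adds one missing value to show the extra indicator column that `dummy_na=True` creates."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "# Hypothetical discrete feature with two categories and one missing value\n",
    "toy = pd.DataFrame({'MSZoning': ['RL', 'RM', np.nan]})\n",
    "\n",
    "# The original column is replaced by one indicator column per category,\n",
    "# plus one column that flags missing values\n",
    "dummies = pd.get_dummies(toy, dummy_na=True)\n",
    "\n",
    "print(list(dummies.columns))\n",
    "# Cast to int so the indicators print as 0/1 regardless of the pandas version\n",
    "print(dummies.astype(int).values.tolist())"
   ]
  },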
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "可以看到这一步转换将特征数从79增加到了331。\n",
    "最后，通过values属性得到NumPy格式的数据，并转成NDArray方便后面的训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_train = train_data.shape[0]\n",
    "train_features = np.array(all_features[:n_train].values,dtype=np.float)\n",
    "test_features = np.array(all_features[n_train:].values,dtype=np.float)\n",
    "train_labels = np.array(train_data.SalePrice.values.reshape(-1, 1),dtype=np.float)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "3.16.4. 训练模型¶\n",
    "我们使用一个基本的线性回归模型和平方损失函数来训练模型,因为tensorflow中的keras中有对数均方根误差，故我们直接就使用。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_f=tf.keras.losses.mse\n",
    "def get_net():\n",
    "    net = keras.models.Sequential()\n",
    "    net.add(keras.layers.Dense(1))\n",
    "    return net"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "对数均方根误差的实现如下，因为已经keras中已经集成了对数均方根误差，故直接调用"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {},
   "outputs": [],
   "source": [
    "log_rmse=tf.keras.losses.mean_squared_logarithmic_error"
   ]
  },
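  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the relation between the Keras loss and the competition metric concrete, here is a NumPy sketch. The helper names `msle` and `rmsle` and the sample prices are illustrative assumptions. The formula follows the documented definition of the Keras loss and ignores the clipping of very small predictions that Keras applies internally."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def msle(y_true, y_pred):\n",
    "    # Mean of squared differences between log(1 + y_true) and log(1 + y_pred)\n",
    "    return np.mean((np.log1p(y_true) - np.log1p(y_pred)) ** 2)\n",
    "\n",
    "def rmsle(y_true, y_pred):\n",
    "    # Square root of MSLE, which is on the scale of the competition metric\n",
    "    return np.sqrt(msle(y_true, y_pred))\n",
    "\n",
    "# Hypothetical true and predicted prices\n",
    "y_true = np.array([208500.0, 181500.0])\n",
    "y_pred = np.array([200000.0, 190000.0])\n",
    "print(msle(y_true, y_pred), rmsle(y_true, y_pred))"
   ]
  },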
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "3.16.5.  K 折交叉验证¶\n",
    "我们在“模型选择、欠拟合和过拟合”一节中介绍了 K 折交叉验证。它将被用来选择模型设计并调节超参数。下面实现了一个函数，它返回第i折交叉验证时所需要的训练和验证数据。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def get_k_fold_data(k, i, X, y):\n",
    "    assert k > 1\n",
    "    fold_size = X.shape[0] // k\n",
    "    X_train, y_train = None, None\n",
    "    for j in range(k):\n",
    "        idx = slice(j * fold_size, (j + 1) * fold_size)\n",
    "        X_part, y_part = X[idx, :], y[idx]\n",
    "        if j == i:\n",
    "            X_valid, y_valid = X_part, y_part\n",
    "        elif X_train is None:\n",
    "            X_train, y_train = X_part, y_part\n",
    "        else:\n",
    "            X_train = tf.concat([X_train, X_part], axis=0)\n",
    "            y_train = tf.concat([y_train, y_part], axis=0)\n",
    "    return X_train, y_train, X_valid, y_valid"
   ]
  },
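  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The index arithmetic in `get_k_fold_data` can be traced on a small array. The sketch below is an illustration with made-up sizes (11 samples, 5 folds), not a replacement for the function. It shows that each fold holds `fold_size` consecutive samples and that samples beyond `k * fold_size` are never used. With the 1460 training samples and k = 5 nothing is left over, since 1460 is divisible by 5."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "X = np.arange(22).reshape(11, 2)  # 11 hypothetical samples with 2 features each\n",
    "k = 5\n",
    "fold_size = X.shape[0] // k       # integer division, as in get_k_fold_data\n",
    "\n",
    "for i in range(k):\n",
    "    # Fold i is the validation slice; the other k - 1 folds form the training set\n",
    "    valid_idx = np.arange(i * fold_size, (i + 1) * fold_size)\n",
    "    train_idx = np.setdiff1d(np.arange(k * fold_size), valid_idx)\n",
    "    print(i, valid_idx.tolist(), len(train_idx))\n",
    "\n",
    "# Samples past k * fold_size belong to no fold\n",
    "n_used = k * fold_size\n",
    "print('samples never used:', X.shape[0] - n_used)"
   ]
  },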
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在 K 折交叉验证中我们训练 K 次并返回训练和验证的平均误差。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "metadata": {},
   "outputs": [],
   "source": [
    "def k_fold(k, X_train, y_train, num_epochs,\n",
    "           learning_rate, weight_decay, batch_size):\n",
    "    train_l_sum, valid_l_sum = 0, 0\n",
    "    for i in range(k):\n",
    "  # create model\n",
    "        data = get_k_fold_data(k, i, X_train, y_train)\n",
    "        net=get_net()\n",
    "    # Compile model\n",
    "        net.compile(loss=tf.keras.losses.mean_squared_logarithmic_error, optimizer=tf.keras.optimizers.Adam(learning_rate))\n",
    "    # Fit the model\n",
    "        history=net.fit(data[0], data[1],validation_data=(data[2], data[3]), epochs=num_epochs, batch_size=batch_size,validation_freq=1,verbose=0)\n",
    "        loss = history.history['loss']\n",
    "        val_loss = history.history['val_loss']\n",
    "        print('fold %d, train rmse %f, valid rmse %f'\n",
    "              % (i, loss[-1], val_loss[-1]))\n",
    "    plt.subplot(1, 2, 2)\n",
    "    plt.plot(loss, label='train')\n",
    "    plt.plot(val_loss, label='valid')\n",
    "    plt.legend(loc='upper right')\n",
    "    plt.title('Training and Validation Loss')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "3.16.6. 模型选择¶\n",
    "我们使用一组未经调优的超参数并计算交叉验证误差。可以改动这些超参数来尽可能减小平均测试误差"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fold 0, train rmse 11.366129, valid rmse 11.477915\n",
      "fold 1, train rmse 11.860233, valid rmse 11.891942\n",
      "fold 2, train rmse 4.521535, valid rmse 4.599776\n",
      "fold 3, train rmse 5.571686, valid rmse 5.324833\n",
      "fold 4, train rmse 11.138485, valid rmse 11.130119\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAMMAAAEICAYAAADiJ0BpAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJztnXt4VNXZt+9nJicCSYAECBAwICoIYjho8YQoohxE5RVE\nq9ZaK2qtqFXf8tn2rW1ta6tCq7UqVCoqghRU8EAREEQKiqAIQRDkHAiHBIGAnEKe74+9Q0fMYZJM\nZu9Mnvu65prZx/WbNfs3az17rb2WqCqGYUDAawGG4RfMDIbhYmYwDBczg2G4mBkMw8XMYBguvjSD\niARF5ICItI3kvl4iIh1EpFbuY598bhF5T0RurA0dIvIrEXmuusf7mYiYwb0YS18lInIoZLnMH6Ui\nVPW4qjZS1S2R3NeviMgcEfm/MtZfKyLbRCRYlfOp6uWqOjECui4TkU0nnft3qnpnTc9dRlo/FpH5\nkT5vVYiIGdyLsZGqNgK2AIND1n3nRxGRuEikG0NMAG4uY/3NwCuqejzKeuolUakmicijIvKaiEwS\nkSLgJhE5T0Q+EpG9IpIvIk+JSLy7f5yIqIhku8uvuNtnikiRiCwWkXZV3dfdPkBE1orIPhF5WkT+\nIyI/LEd3OBrvEJGvRORrEXkq5NigiIwRkUIR2QD0ryCLXgcyReT8kOPTgYHAS+7yVSKyXET2i8gW\nEflVBfm9sPQ7VabD/Ude7ebVehH5sbs+DXgLaBtSyjd3f8sXQ44fIiKr3Dx6X0TOCNmWJyI/E5GV\nbn5PEpHECvKhvO+TJSJvi8geEVknIj8K2dZLRD5182WniDzurk8WkVfd771XRJaISEaFCalqRF/A\nJuCyk9Y9ChwFBuMYsAFwDvA9IA5oD6wFfuruHwcokO0uvwIUAD2BeOA1nH/Mqu7bHCgCrna3/Qw4\nBvywnO8SjsbpQBqQDewp/e7AT4FVQBaQDixwsrvcfPsn8FzI8t3A0pDlS4HObv6d7X7HK91tHULP\nDSws/U6V6XB/k/aAuGkcArq62y4DNpXxW77ofu4EHHCPiwceBr4E4t3tecBHQKab9lrgx+V8/x8D\n88vZ9h/gaSAJ6O5+94vdbZ8AN7ifU4DvheTfmzjXWtC9HhpVdO1GM4BeqKpvqWqJqh5S1U9U9WNV\nLVbVDcBY4OIKjp+qqktV9RgwEcipxr5XAstVdbq7bQxOxpZJmBr/qKr7VHUTMD8kreuAMaqap6qF\nwGMV6AWnqnRdyD/nD9x1pVreV9VVbv59DkwuQ0tZVKjD/U02qMP7wFzgojDOC3A9MMPVdsw9dxrO\nH0gpf1HVHW7ab1Px7/Yd3FL9XGCUqh5W1U9x/jhKq5XHgNNEJF1Vi1T145D1GUAHdeLKpap6oKK0\nommGraELItJRRN4RkR0ish/4LY748tgR8vkboFE19m0VqkOdv5C88k4Spsaw0gI2V6AX4ANgPzBY\nRE4HugGTQrScJyLzRWS3iOzD+SetuNgPQ4eIXCkiH7tVkL3A5WGet/TcJ86nqiU4+dk6ZJ+q/G7l\npVGgqgdD1m0OSeNW4EzgS7cqNNBd/yIwB5gizk2IxyqLVaNphpNv5z0P5OI4NxX4P5yiujbJx6ku\nACAiwrd/uJOpicZ8oE3IcoW3fl1jvoRTItwMvKuqoaXWZGAa0EZV04B/hKmlXB0i0gCYCvwRaKGq\njYH3Qs5b2S3Y7cApIecL4OTvtjB0hct2IENEGoasa1uahqp+qarX41SBnwSmiUiSqh5V1UdUtRNw\nITAEqPDOppftDCnAPuCgiHQC7ohCmm8D3UVksPsvcS/QrJY0TgHuE5HWbjD88zCOeQknwP0RIVWk\nEC17VPWwiPTCqaLUVEcikADsBo6LyJVA35DtO3EuxJQKzn2ViPRxbyw8hBOTfVzO/pUREJGk0Jeq\nbgSWAn8QkUQRycEpDV4BEJGbRSTDLZ
X24Ri4REQuFZEurkH341SbSipMvJqiI8EDwC04mfc8TqBb\nq6jqTmA4MBooBE4FPgOO1ILGZ3Hq3ytxgrypYej7CliCc5G+c9Lmu4A/inM37mGcC7FGOlR1L3A/\n8AZO8D8U5w+jdHsuTmm0yb0j0/wkvatw8udZHEP1B65y44fqcBFOAB/6Auc3Ow2nyjUVeFhV57vb\nBgKr3Xx5Ahiuqkdxqlev4xhhFU6V6dWKEhetxw/3iNOYtR0Yqqofeq3H8BZfdseoTUSkv4g0du/a\n/Aqn+FzisSzDB9Q7M+AEUxtwivUrgCGqWl41yahH1OtqkmGEUh9LBsMok6h2mMvIyNDs7OxoJmnU\nI5YtW1agqhXdKq+QqJohOzubpUuXRjNJox4hIpW18leIVZMMw8XMYBguZgbDcLEnzmKEY8eOkZeX\nx+HDh72WUuskJSWRlZVFfHx8RM9rZogR8vLySElJITs7G6czbmyiqhQWFpKXl0e7du0qP6AKWDUp\nRjh8+DDp6ekxbQQAESE9Pb1WSkAzQwwR60Yopba+py/MMG7BBmat2lH5joZRi/jCDP/8z0bmfLHT\naxlGDdi7dy9///vfq3zcwIED2bt3by0oqjq+MEMwKBwvsQ6DdZnyzFBcXFzhce+++y6NGzeuLVlV\nwhd3k+ICAYrNDHWaUaNGsX79enJycoiPjycpKYkmTZqwZs0a1q5dyzXXXMPWrVs5fPgw9957LyNG\njAD+20XnwIEDDBgwgAsvvJBFixbRunVrpk+fToMGDaL2HSo1g4gk4Yy1k+juP1VVfy0ijwC34zwX\nAM6jeO9WR0RAsJIhgvzmrVV8sX1/RM95ZqtUfj24c7nbH3vsMXJzc1m+fDnz589n0KBB5Obmnrj9\nOX78eJo2bcqhQ4c455xzuPbaa0lPT//WOdatW8ekSZMYN24c1113HdOmTeOmm26K6PeoiHBKhiPA\npap6wH3oe6GIzHS3jVHVJ2osIhAwM8QY55577rfaAZ566ineeOMNALZu3cq6deu+Y4Z27dqRk+MM\nq9SjRw82bdoUNb0QhhncIUxKB1+Kd18RvXKDAbFqUgSp6B88WjRs+N+RXebPn8+cOXNYvHgxycnJ\n9OnTp8x2gsTE/448GQwGOXTo0Hf2qU3CCqDd8TqXA7uA2SGjlt0jIitEZLyINCnn2BEislRElu7e\nvbusXYgLCsdLKhzFw/A5KSkpFBUVlblt3759NGnShOTkZNasWcNHH30UZXXhEZYZ3OH5cnAGiDpX\nRLrgDA/SHme4wHycAZzKOnasqvZU1Z7NmpX93EVArGSo66Snp3PBBRfQpUsXHnrooW9t69+/P8XF\nxXTq1IlRo0bRq1cvj1RWTJXuJqnqXhGZB/QPjRVEZBwh4+1UWUTAbq3GAq++WvawRImJicycObPM\nbaVxQUZGBrm5uSfWP/jggxHXVxmVlgwi0kxEGrufGwD9gDUi0jJktyE4wzBWi6CZwfAB4ZQMLYEJ\n7oBbAWCKqr4tIi+7Q/0pzjD01R4eMi4oHDlmMYPhLeHcTVqBMyL0yevLmmmmWljMYPgBX3THsJjB\n8AO+MEPQGt0MH+ALM1jJYPgBX5jBaYG2ALo+0aiRM4HP9u3bGTp0aJn79OnTJ6rjbPnGDFYy1E9a\ntWrF1KmVTl0RFXxhhriAcNwGQK7TjBo1imeeeebE8iOPPMKjjz5K37596d69O2eddRbTp0//znGb\nNm2iS5cuABw6dIjrr7+eTp06MWTIkKj3TfLF8wzBgHD8uJkhYswcBTtWRvacmWfBgPInLB0+fDj3\n3Xcfd999NwBTpkxh1qxZjBw5ktTUVAoKCujVqxdXXXVVuc8wP/vssyQnJ7N69WpWrFhB9+7dI/sd\nKsEXZogLWjtDXadbt27s2rWL7du3s3v3bpo0aUJmZib3338/CxYsIBAIsG3bNnbu3ElmZmaZ51iw\nYA
EjR44EoGvXrnTt2jWaX8EfZgiIxQwRpYJ/8Npk2LBhTJ06lR07djB8+HAmTpzI7t27WbZsGfHx\n8WRnZ/t6kDOLGYyIMXz4cCZPnszUqVMZNmwY+/bto3nz5sTHxzNv3jw2b654kOzevXuf6OyXm5vL\nihUroiH7BL4oGYKBgMUMMUDnzp0pKiqidevWtGzZkhtvvJHBgwdz1lln0bNnTzp27Fjh8XfddRe3\n3nornTp1olOnTvTo0SNKyh18YQaLGWKHlSv/G7hnZGSwePHiMvc7cMB5eDI7O/tE1+0GDRowefLk\n2hdZDr6oJlnMYPgBX5jBYgbDD/jCDKUt0DbzaM2oL/lXW9/TF2aICziNMFZVqj5JSUkUFhbGvCFK\nh6RPSkqK+Ll9EUAHXDMUlyhxQY/F1FGysrLIy8ujvBFIYonSyUoijS/MYCVDzYmPj4/45B31DV9U\nk4KlZojxIt7wN74ww4mSwRreDA/xhRmCITGDYXiFT8zgyLCYwfASX5ghzmIGwwf4wgxBixkMH+Ar\nM9igAIaX+MoMFjMYXuILM1jMYPgBX5jhRDXJYgbDQ3xlBqsmGV7iKzNYo5vhJb4wQ5zb6FZiMYPh\nIeHM3JMkIktE5HMRWSUiv3HXNxWR2SKyzn0vc4LDcLCYwfAD4ZQMpfNAn40zmWF/EekFjALmqupp\nwFx3uVpYzGD4gUrNoA5lzQN9NTDBXT8BuKa6IqzRzfADNZkHuoWq5ru77ABalHNs5fNAu2awmMHw\nkprMAx26XXFKi7KOrXQeaIsZDD9QpbtJqroXmAf0B3aWTn/rvu+qroi4oMUMhvdUex5oYAZwi7vb\nLcB3B98Pk6BYO4PhPTWZB3oxMEVEbgM2A9dVV0TQYgbDB9RkHuhCoG9ERLiNbhYzGF7iixbooMUM\nhg/whxksZjB8gD/McKIF2hrdDO/whRlsRD3DD/jCDKUxg1WTDC/xhxnESgbDe/xhBnu4x/ABvjDD\niY56ZgbDQ3xhBisZDD/gCzOICAGxmMHwFl+YAZwuGVYyGF7iGzMEA2Id9QxP8Y0Z4gJiHfUMT/GN\nGQIBse4Yhqf4xgxxAbGYwfAU35jBYgbDa3xjBosZDK/xjRmcmMHMYHiHb8xgMYPhNb4xQzAgNlmJ\n4Sm+MUNcIGATHBqe4hszBKyaZHiMb8wQZ41uhsf4xgxBKxkMj/GNGeKs0c3wGN+YIWCNbobH+MYM\ncdboZniMb8xgMYPhNb4xg8UMhtf4xgzBQMBiBsNTfGQGGxDA8JZwZu5pIyLzROQLdx7oe931j4jI\nNhFZ7r4G1kSIMyCANboZ3hHOzD3FwAOq+qmIpADLRGS2u22Mqj4RCSHOwz2ROJNhVI9wZu7JB/Ld\nz0UishpoHXEhAbGSwfCUKsUMIpKNM6XVx+6qe0RkhYiMF5Em5RxT6TzQ4D7cYwG04SFhm0FEGgHT\ngPtUdT/wLNAeyMEpOZ4s67hw5oFm4jAu3z3B2hkMTwnLDCISj2OEiar6OoCq7nQnSy8BxgHnVlvF\n3q1kHV5r7QyGp4RzN0mAF4DVqjo6ZH3LkN2GALnVVpGWReNju6xkMDwlnLtJFwA3AytFZLm77mHg\nBhHJARTYBNxRbRVpWaRt+sRiBsNTwrmbtBCQMja9GzEVaVk0LN5LXMnhiJ3SMKqKP1qg09oA0EwL\nPBZi1Gd8YoYsAFqYGQwP8ZUZMilA7Y6S4RH+MENqKxShtRRw6Nhxr9UY9RR/mCEYz5EGzWlFIWt2\nFHmtxqin+MMMgDRuQyspIHfbPq+lGPUU35ghoWkb2gT3sDLPzGB4g2/MIGlZtKSAVXl7vJZi1FN8\nYwZa9ySBYzQt+ITDFkQbHuAfM5x2OcVxyQySRazO3++1GqMe4h8zJCRztMMABgSX8Mn6nV6rMeoh\n/jEDkNztOhrLQTYsftMGBzCijq/MwKmXcqhBS4Yensa81VY6GNHF
X2aISyDhkgfpGVjLkrlTvVZj\n1DP8ZQYg2P1mDiRmMrBgPB+u3eW1HKMe4TszEJdIYr+HyQmsZ+GMFyix2MGIEv4zAxDf/Sb2p3Tg\n+v3/5K3PNnktx6gn+NIMBII0uvIPtAvsZP3Mpzl01BrhjNrHn2YAAqdfzr7M87n12Gu89P7yyg8w\njBriWzMgQto1fyZNDpKw6Eny9x3yWpER4/jXDACZZ/FN5xu4Sf7N+Ddnea3GiHH8bQag0YDfcjyu\nARd8NZqP19sz0kbt4Xsz0KgZgUsepk/wc2ZOe4Fjx21wYqN28L8ZgITz7qAo7QxuP/g8Ez9c7bUc\nI0apE2YgGEej//krraWQ4/P+yLa9FkwbkadumAGQU87jQOfvcwvv8I8pb9qQMkbEqTNmAGg06Pcc\nS0jj6rw/M/PzPK/lGDFGnTIDyU1JuPJxcgIbWDvjT+z95qjXiowYom6ZAQh2Hcr+Uy7nzuOT+ftU\na3swIkedMwMipA59GuIbcMVXv2F27navFRkxQt0zA0BKJnGDHqdHYB1fvP4HCg8c8VqREQPUZB7o\npiIyW0TWue9lTnBYW8TlDGd/dn/uOj6Jv0+ebneXjBoTTslQOg/0mUAv4G4RORMYBcxV1dOAue5y\n9BAhddgzFCemMXTLb5n28fqoJm/EHpWaQVXzVfVT93MRUDoP9NXABHe3CcA1tSWyXBpmkHTts3QK\nbOXwzF+wYfeBqEswYoeazAPdwp0wHWAH0CKiysIkcMYVHOx2OzfJv3npxedsND6j2tRkHugTqFNh\nL7PSHu6k6DWh4aDfs7/xmdx7YAx/e3N+raRhxD7Vngca2Fk6/a37XuZQFmFNil5T4hJJvfkVkoMl\n9Fk5ineWb66ddIyYptrzQAMzgFvcz7cA0yMvrwqkn0rwmqfpGVhL4Ruj+GqXxQ9G1QinZCidB/pS\nEVnuvgYCjwH9RGQdcJm77ClxXYdysNvt/EDeZfL40RQdPua1JKMOUZN5oAH6RlZOzWl45R/Zv205\nD+x8msdfPo1f3nY9gUB58g3jv9TNFuiKCMaT+oNXKUlqwm15v2Dsu4u9VmTUEWLPDACNmpN8yxSa\nBQ5yzpKRvLVsg9eKjDpAbJoBkFY5yP88T4/AOph+N0s3FnotyfA5MWsGgPizruHQRb9gcGARn014\niI0FB72WZPiYmDYDQINLH6LozBu4nWn8a+yj7C6yHq5G2cS8GRAh5dqn2ZfVhweOPMvzY//KgSPF\nXqsyfEjsmwEgGE/aD17lQEZXHtr/Z8aM/Yf1YTK+Q/0wA0BCQ9Jue5MjqW25v+DXPPnPSTYgmfEt\n6o8ZAJKbknr7O2hyOj/d9hBPTJhqEykaJ6hfZgBIbUnKiJkEklK5Y/P9jH75dTOEAdRHMwA0OYWU\nEe+SkJjMbRtGMublaWYIo56aASD9VBrd8W/iExty24Z7Gf3SFIothqjX1F8zAKSfSspd7xFMSuGO\njfcxevzLHCm2u0z1lfptBoAm2aT+ZA7aMIOf5j3EX55/nm+OWjtEfcTMAJCWRdpdcziaegr37/ol\nz/ztcRu6sh5iZiglpQWNfzKbAxk5PLDvT7zy14dt6Pt6hpkhlAaNaXrnO+xt24+fHhnH+0+NYNW2\nr71WZUQJM8PJxDeg6a2T+brzLdxcMoO8scOZn2sDDNQHzAxlEQjSZOhfKer9CP1kCY2nDGHS3CU2\nhGWMY2YoDxFSLr2fY9dOoGNwG30WDOeZif/iaLG1RcQqZoZKSDzrahJun01yYgI/XvcTnnv6D/ZM\nRIxiZgiDQKuupN37Hw40y2Hkvsd5f8ytfLqpzDHTjDqMmSFcGmaQcddMCrvcxvCSdygZfyX/mmdx\nRCxhZqgKwXjSh47mm8Fj6RLczCXzr+WZf4yzwcpiBDNDNUjuMZyEOz9AGjbjJ3n/y/Qn72DllgKv\nZRk1xMxQTQItOpJ+30IKTh/G
TcemUfyP/kye9SEl1hW8zmJmqAkJyTS/cRwHBo+lY3A7gxYNY+zf\nfs8O68ZRJzEzRIBGPYaTNHIRh5p24s49j7PiL9cw65NVXssyqoiZIUJIk2ya3zOHwvMe5hKW0v3t\n/owd+zRfH7Ter3UFM0MkCQRJv+LnyIh50CiTEdt/ycdPDGHuMisl6gJmhlogrlVXmv1sEbt6PEBf\n/YicGVcw/rk/s2u/xRJ+xsxQWwTjaT74/+CODyhObcuPdvyeNaMHMOODj+yOk08JZxqr8SKyS0Ry\nQ9Y9IiLbTprJxyiD+JZdaHH/hxRc+Ajnymoue/8qJo55gLXb93gtzTiJcEqGF4H+Zawfo6o57uvd\nyMqKMQJBMi67n4SRn/B15nncXPQC8tyFTJz8io376iPCmRR9AWB/YxEg0KQtre+aTtGQV0hPLOHG\nNXez6E9XMWvxMuvj5ANqEjPcIyIr3GpUk/J2isY80HWNlLMH0/R/P2N7t/u4uOQTLvr3ACY/OZLc\nTTu9llavkXD+kUQkG3hbVbu4yy2AApyJ0H8HtFTVH1V2np49e+rSpUtrojfmOL5nE9unPEibHbPZ\nphnMa/MT+g27ixZpyV5Lq3OIyDJV7Vnd46tVMqjqTlU9rqolwDjg3OoKqO8Em2bT5s6pHLzhTYIN\nm3JT3m/ZNfoCpvxrksUTUaZaZhCRliGLQ4Dc8vY1wqPhGZeQ+eDHFF72V7Lii7hu1Z18+tjlzHhv\ntj1qGiUqrSaJyCSgD5AB7AR+7S7n4FSTNgF3qGp+ZYlZNSlMjh1i26y/kLbsbySXHGR2XG/kklH0\nPf88gjandbnUtJoUVswQKcwMVUO/2cOWGX+gxZqXCGoxsxP6ktxvFBef0x0RM8XJmBnqASX78tk0\n/XdkbXgNVJmddAUp/X7ORT3ONlOEYGaoRxTv2cKW6b+jzebXKVGYm3gZDfs+xEXn9CBg1SczQ32k\neM9mNk//PW02T0NUmZfQh2Dv+7n4/AuIC9bf7mZmhnpM8ddb2fzWn2i94TUS9BgL4npx6Nx7uOTS\n/iTFB72WF3XMDAYlRbvZ+M4TtPjyZRrpQT6RLuzoPILeA24grWGC1/KihpnBOIEe3sfm954l9fNx\nND1ewDrN4ovsH9B90AjaNC+3x0zMYGYwvkvxUbYtfAVZ/DStjmxgt6bxUfoQ2vS7m7M7nhazd6DM\nDEb5qLJn5Xvsef8vdNi7iCMax8Kki+HcO7jw4stIjIutuMLMYITFofzVbJk5hrZbptOAw6zgdLac\n+n26DfghrTNiowplZjCqhB7ay8Y540he8SKZx/Io1BSWNB5I44tu53vde9bp9gozg1E9SkrYvWIW\nX3/wHO2/XkAcJSwJ5FDY8fv06Pd9mjdJ8VphlTEzGDXm6J48Nsx+lmZfTia9pIACTWVZ4/6knv8j\nzun5vTrTkGdmMCJHyXHyl73N/kUvcOrXHxJHCZ9LR7a3v5ZOfW8hu1ULrxVWiJnBqBWO7s1n45xx\npK6ZQsvirXyjiSxpcCF69g306HMVqQ0SvZb4HcwMRu2iSuGahexY8AKn5M+iEd+wXdNZlX4Fqb1u\nokeP83xTjTIzGFFDj37DpkVTObpsIqcWLSGOElbTni2tB9H6opvpfMbpnjbomRkMTziyN5+N818i\nafU0so98SYkKnwW7UNhuMB0uvoH2bdtGXZOZwfCcorwv2PzBBNI3vkXL4m0c0yCfxXdjX4er6Hjx\n9bRpGZ3A28xg+AdV9ny1hG0LXyFz60yalezmiMbzWUJ3DnYYxBm9h5PVMrPWkjczGP6kpIRdaxay\nY9EkWm5/j2YlBRzROFYk5FDUbiDtLhpGuzaRrUqZGQz/U1LCztUL2fnRa2Rue4/mJbso1gAr47qw\np+3lZPUayumnd6xx8G1mMOoWquxe+zHbP5pC+tbZZBVvAWC1dCA/8xKadh9C5269iK9Gj1ozg1
Gn\n+XpzLlsXT6Xhxn9z6pHVAOTRnPVNLiLhzEF0Pr8/qQ0bhnUuM4MRM3yzJ48NC6cRWDuT9geWkcRR\nirQBq5J7cqTdZWT3GsIpbU8p93gzgxGTHD98gA1L3uHgyndoU/Ah6bqHEhW+DHYg+4F5NGj43V61\nNTVDXI0UG0YtEUxqxGm9h0Pv4VBSwo61S8hfOp1A4VdlGiESmBkM/xMIkNmxF5kde9VuMrV6dsOo\nQ5gZDMPFzGAYLmYGw3Cp7jzQTUVktoisc99jY6wRo15T3XmgRwFzVfU0YK67bBh1murOA301MMH9\nPAG4JsK6DCPqVDdmaBEyh9sOoNynN2weaKOuUONGN1VVESm3T4eqjgXGAojIbhHZXM6uGThzS/sB\n01I2ftdSfselMKiuGXaKSEtVzXenwd0VzkGq2qy8bSKytCb9SiKJaSmbWNdS3WrSDOAW9/MtwPTI\nyDEM7wjn1uokYDFwhojkichtwGNAPxFZB1zmLhtGnabSapKq3lDOpr4R1jI2wuerCaalbGJaS1Sf\nZzAMP2PdMQzDxcxgGC6em0FE+ovIlyLylYhEtVuHiLQRkXki8oWIrBKRe931j4jINhFZ7r4GRknP\nJhFZ6aa51F0X9X5gInJGyHdfLiL7ReS+aOVLVfvDicj/c6+fL0XkimonrKqevYAgsB5oDyQAnwNn\nRjH9lkB393MKsBY4E3gEeNCD/NgEZJy07s/AKPfzKOBPHvxGO3AatKKSL0BvoDuQW1k+uL/X50Ai\n0M69noLVSdfrkuFc4CtV3aCqR4HJOP2eooKq5qvqp+7nImA10Dpa6YeJ1/3A+gLrVbW8ngMRR6vW\nH+5qYLKqHlHVjcBXONdVlfHaDK2BrSHLeXh0MYpINtAN+NhddY+IrHCL7Gh1UVdgjogsE5ER7rqw\n+4HVEtcDk0KWvcgXKD8fInYNeW0GXyAijYBpwH2quh94FqfqlgPkA09GScqFqpoDDADuFpHeoRvV\nqRdE7V64iCQAVwH/cld5lS/forbywWszbAPahCxnueuihojE4xhhoqq+DqCqO1X1uKqWAOOoZrFb\nVVR1m/u+C3jDTXen2/+LqvQDixADgE9Vdaery5N8cSkvHyJ2DXlthk+A00SknfsvdD1Ov6eoIM5I\nty8Aq1V1dMj6liG7DQFyTz62FrQ0FJGU0s/A5W66XvYDu4GQKpIX+RJCefkwA7heRBJFpB1wGrCk\nWilE+45JGXcOBuLcxVkP/CLKaV+IU9yuAJa7r4HAy8BKd/0MoGUUtLTHuSvyObCqNC+AdJynCdcB\nc4CmUcoZvcxbAAAATUlEQVSbhkAhkBayLir5gmPAfOAYTgxwW0X5APzCvX6+BAZUN13rjmEYLl5X\nkwzDN5gZDMPFzGAYLmYGw3AxMxiGi5nBMFzMDIbh8v8Bj8lgQOYuEokAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x523bc91ef0>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "k, num_epochs, lr, weight_decay, batch_size = 5, 100, 5, 0, 64\n",
    "k_fold(k, train_features, train_labels, num_epochs, lr,\n",
    "                          weight_decay, batch_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "3.16.7. 预测并在Kaggle提交结果¶\n",
    "下面定义预测函数。在预测之前，我们会使用完整的训练数据集来重新训练模型，并将预测结果存成提交所需要的格式。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_train=tf.convert_to_tensor(train_features,dtype=tf.float32)\n",
    "y_train=tf.convert_to_tensor(train_labels,dtype=tf.float32)\n",
    "x_test=tf.convert_to_tensor(test_features,dtype=tf.float32)\n",
    "model=tf.keras.models.Sequential([\n",
    "  tf.keras.layers.Dense(1)\n",
    "])\n",
    "adam=tf.keras.optimizers.Adam(0.5)\n",
    "model.compile(optimizer=adam,\n",
    "              loss=tf.keras.losses.mean_squared_logarithmic_error\n",
    "              )\n",
    "model.fit(x_train, y_train, epochs=200,batch_size=32,verbose=0)\n",
    "preds=np.array(model.predict(x_test))\n",
    "test_data['SalePrice'] = pd.Series(preds.reshape(1, -1)[0])\n",
    "submission = pd.concat([test_data['Id'], test_data['SalePrice']], axis=1)\n",
    "submission.to_csv('submission.csv', index=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
